import torch
import numpy as np
import matplotlib.pyplot as plt
import os

# Starting coordinates (x, y). Presumably the start point of an optimisation
# run -- confirm against the code that consumes `initial_point`.
initial_x, initial_y = 4.0, -1

# Tensor built from the pair above; requires_grad=True asks autograd to
# track operations on it so gradients can be computed with respect to it.
initial_point = torch.tensor(
    (initial_x, initial_y),
    requires_grad=True,
)

def heatmap_3D_gif(x, y, Z, trajectory1, trajectory2, args):
    """Save a filled-contour plot of ``Z`` with two trajectories overlaid.

    Despite the name, each call renders one static 2-D PNG; nothing in this
    function assembles a GIF or draws in 3-D.

    Args:
        x: 1-D coordinates passed to ``np.meshgrid`` as the first argument.
        y: 1-D coordinates passed to ``np.meshgrid`` as the second argument.
        Z: Function values handed to ``plt.contourf``; must be compatible
            with the grid produced by ``np.meshgrid(x, y)``.
        trajectory1: Array-like whose element 0 is plotted on the x-axis and
            element 1 on the y-axis (drawn in red).
        trajectory2: Same layout as ``trajectory1`` (drawn in blue).
        args: Object providing ``trajectory1`` and ``trajectory2`` (legend
            labels) plus ``task_name`` and ``id`` (used to build the output
            path ``../gif/<task_name>/<id>.png``).

    The output directory is created if it does not exist, and the figure is
    closed before returning so repeated calls do not accumulate open figures.
    """
    X, Y = np.meshgrid(x, y)
    out_path = f'../gif/{args.task_name}/{args.id}.png'

    fig = plt.figure(figsize=(8, 6))
    try:
        contour_plot = plt.contourf(X, Y, Z, levels=20, cmap='viridis', alpha=0.3)
        plt.colorbar(contour_plot, label='Function Value')

        # NOTE(review): indexing with [0] / [1] selects along the first axis,
        # so this expects each trajectory laid out as [xs, ys]. A list of
        # (x, y) points would need path[:, 0] / path[:, 1] -- confirm with
        # the caller before changing.
        overlays = (
            (trajectory1, 'red', f'{args.trajectory1}'),
            (trajectory2, 'blue', f'{args.trajectory2}'),
        )
        for trajectory, colour, label in overlays:
            path = np.array(trajectory)
            plt.plot(path[0], path[1], color=colour, linestyle='-',
                     label=label, alpha=0.9)

        plt.title('line')
        plt.xlabel('X-axis')
        plt.ylabel('Y-axis')
        plt.legend()

        # savefig does not create missing parent directories.
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        plt.savefig(out_path)
    finally:
        # pyplot keeps every figure alive until it is explicitly closed.
        plt.close(fig)